#!/usr/bin/env python3
# -*-Python-*-

import re, sys, os
from subprocess import Popen, STDOUT, PIPE, DEVNULL, \
                SubprocessError, TimeoutExpired
from threading import Thread
from queue import Queue, Empty, Full
from getopt import getopt, GetoptError
from os.path import basename, join, splitext

# Caps on the number of stderr / stdout lines read from a tested program
# before its stream is closed and the run is flagged.
MAX_ERROR_LINES = 3000
MAX_OUTPUT_LINES = 3000

# Seconds allowed for a process to exit in Prog.stop (and for a "win+" reply).
SHORT_WAIT = 5
# Seconds per polling step when waiting on a message queue.
QUANTUM = 0.05
# Unique end-of-stream marker passed through the queues.
EOS = object()

# A win announcement, e.g. "* White wins.".
WIN_PATN = re.compile(r'\s*\*\s*((?:Black|White)\s+wins\.)\s*$')
# A reported move, e.g. "* a1-b"; the move text is the group 'move'.
MOVE_PATN = re.compile(r'\s*\*\s*(?P<move>[a-i][1-9]-[a-i1-9])')
# Either of the above; the group 'win' is set only for a win announcement.
MSG_PATN = re.compile('(?P<win>' + WIN_PATN.pattern + ')|'
                      + MOVE_PATN.pattern)
def Usage():
    """Print the command-line synopsis on the standard error and exit
    with status 1."""
    synopsis = ("Usage: test-tablut [ --verbose ] PROG1-SCRIPT.in "
                "[ PROG2-SCRIPT.in ]")
    print(synopsis, file=sys.stderr)
    raise SystemExit(1)

def no_except(func):
    """Call FUNC with no arguments, discarding any ordinary exception it
    raises.  Used for best-effort cleanup such as closing a pipe or
    killing a process.  Always returns None.

    Only subclasses of Exception are suppressed: KeyboardInterrupt,
    SystemExit, and other BaseException-only signals still propagate, so
    cleanup code cannot silently cancel a shutdown request."""
    try:
        func()
    except Exception:
        pass

class Terminate(BaseException):
    """Raised to abandon the processing of a command script; caught in
    Prog._controller_thread_runner.  Derives from BaseException rather
    than Exception, so 'except Exception' clauses do not intercept it."""

def queue_printer(queue, dest):
    """Start a daemon thread that writes each item taken from QUEUE to
    DEST (using DEST.write) until it takes the EOS marker.  The thread
    also ends quietly if a write raises ValueError.  Returns None."""

    def drain():
        try:
            item = queue.get()
            while item is not EOS:
                dest.write(item)
                item = queue.get()
        except ValueError:
            pass

    Thread(target=drain, daemon=True).start()

def get_run_command(inp, name):
    """Read lines from INP (with INP.readline) up to and including the
    first one of the form '#* COMMAND', and return COMMAND with the
    surrounding whitespace removed.  If INP is exhausted without such a
    line, report the problem on the standard error, identifying the input
    as NAME, and exit with status 1."""
    for text in iter(inp.readline, ''):
        found = re.match(r'\s*#\*\s*(.*?)\s*$', text)
        if found:
            return found.group(1)
    print("Error: Could not find initial command line in", name,
          file=sys.stderr)
    sys.exit(1)

class Prog:
    """A program under test: a child process together with the threads
    that feed it a command script and monitor its output.

    Lines of the command script are copied to the process's standard
    input, except for lines of the form '#* DIRECTIVE', which tell the
    controller thread to wait for and check move or win messages (lines
    starting with '*') in the process's standard output.  Two Progs may
    be linked with set_other so that each one's moves are relayed to the
    other.  The first problem detected is recorded and is available from
    end_message()."""

    def __init__(self, command, id,
                 commands_in, error_dest, output_dest,
                 logging_queue=None):
        """Start COMMAND (a whitespace-separated command line) as a child
        process and begin monitoring its standard output and error.

        ID is a label used in log lines and diagnostics.  COMMANDS_IN is
        an iterable of script lines, consumed once start() is called.
        ERROR_DEST receives the process's standard error and OUTPUT_DEST
        the reportable part of its standard output.  If LOGGING_QUEUE is
        not None, a trace of all traffic is put on it.

        If COMMAND cannot be executed, the failure is recorded as the end
        message and printed on ERROR_DEST, failed() is true, and no
        threads are started."""
        self._id = id
        self._end_message = None
        # Every attribute is set before the process is spawned so that a
        # Prog whose process failed to start is still a fully formed
        # object.
        self._proc_msg_queue = Queue(100)
        self._move_queue = Queue(100)
        self._output_dest = output_dest
        self._error_dest = error_dest
        self._other_prog = None
        self._commands_in = commands_in
        self._logging_queue = logging_queue
        self._output_thread = None
        self._error_thread = None
        self._control_thread = None
        self._move_time_limit = 10
        self._game_time_limit = 60
        self._time_remaining = self._game_time_limit
        self._proc = self._spawn(command)
        if self._proc is None:
            self._end_message = "could not execute " + command
            print(self._end_message, file=error_dest)
            return
        self._output_thread = Thread(target=self._output_thread_runner)
        self._output_thread.start()
        self._error_thread = Thread(target=self._error_thread_runner)
        self._error_thread.start()

    @staticmethod
    def _spawn(command):
        """Return a new process running COMMAND with all three standard
        streams piped, or None if it cannot be started."""
        # str.split() discards runs of whitespace, so repeated blanks in
        # COMMAND do not turn into empty arguments.
        argv = command.split()
        if not argv:
            return None
        try:
            return Popen(argv, stdin=PIPE, stdout=PIPE, stderr=PIPE)
        except OSError:
            # Includes FileNotFoundError and PermissionError.
            return None

    def start(self):
        """Start the controller thread, which runs the command script.
        Does nothing if the process failed to start."""
        if self._proc:
            self._control_thread = Thread(target=self._controller_thread_runner)
            self._control_thread.start()

    def join(self):
        """Wait for the controller thread, if one was started."""
        if self._proc and self._control_thread:
            self._control_thread.join()

    def failed(self):
        """Return True iff the process could not be started."""
        return self._proc is None

    def set_other(self, other_prog):
        """Make OTHER_PROG the opponent that receives this program's
        moves."""
        self._other_prog = other_prog

    def end_message(self):
        """Return the first problem recorded for this program, or None if
        there has been none."""
        return self._end_message

    def _send_command(self, command):
        """Log COMMAND and send it to the process, adding a newline if
        COMMAND contains none.  Non-ASCII characters are dropped.  A
        failure to write is ignored."""
        if "\n" not in command:
            command += "\n"
        self._log(command, "<")
        try:
            self._proc.stdin.write(bytes(command, encoding="ascii",
                                         errors="ignore"))
            self._proc.stdin.flush()
        except (OSError, ValueError):
            # Broken pipe, or stdin already closed by stop().
            pass

    def _controller_thread_runner(self):
        """Body of the controller thread: run the command script, then
        stop the process.  Ordinary lines are sent to the process;
        '#* DIRECTIVE' lines are executed by _run_directive."""
        try:
            for line in self._commands_in:
                mat = re.match(r'\s*#\*\s*(.*?)\s*$', line)
                if mat:
                    self._run_directive(mat.group(1), line)
                    continue
                self._log(line, "<")
                data = line.encode(encoding='ascii')
                try:
                    self._proc.stdin.write(data)
                    self._proc.stdin.flush()
                except (OSError, ValueError):
                    # The process is gone or its stdin has been closed
                    # (possibly by stop() called from the opponent's
                    # thread); nothing more can be sent.
                    break
            no_except(lambda: self._proc.stdin.close())
        except Terminate:
            pass
        finally:
            # Always shut the process down, even on an unexpected error,
            # so that the reader threads reach end-of-file.
            self.stop()

    def _run_directive(self, directive, line):
        """Execute DIRECTIVE, the text following '#*' on the script line
        LINE.  Raises Terminate (via _error_exit) if DIRECTIVE is not
        recognized or its arguments are malformed."""
        mat = re.match(r'(move/win(\+?))'
                       r'|(remote\s+move/win(\+?))'
                       r'|(move)'
                       r'|(win\+)'
                       r'|time\s+([\d.]+)\s+([\d/]+)',
                       directive)
        if mat is None:
            no_except(lambda: self._proc.stdin.close())
            self._error_exit("Invalid command in testing file: {}"
                             .format(line.rstrip()))
        if mat.group(1):
            self._local_game(mat.group(2))
        elif mat.group(3):
            self._remote_game(mat.group(4))
        elif mat.group(5):
            self._time_remaining = self._game_time_limit
            self._our_move(win_allowed=False)
        elif mat.group(6):
            self._win()
        elif mat.group(7):
            try:
                move_limit = float(mat.group(7))
                game_limit = float(mat.group(8))
            except ValueError:
                # The patterns admit text such as "1.2.3" or "1/2" that
                # float() rejects.
                self._error_exit("Invalid command in testing file: {}"
                                 .format(line.rstrip()))
            self._set_times(move_limit, game_limit)

    def _error_thread_runner(self):
        """Body of the thread that copies the process's standard error to
        the error destination.  Records an end message if a line reports
        an 'Exception in thread', or if more than MAX_ERROR_LINES lines
        arrive, in which case the stream is closed."""
        count = 0
        for line in self._proc.stderr:
            count += 1
            self._log(line, "E>")
            line = line.decode(encoding='ascii', errors='ignore')
            self._error_dest.write(line)
            mat = re.search(r'Exception in thread ".*"\s+(.*)', line)
            if mat:
                self._end_message = self._end_message \
                    or "terminated with " + mat.group(1)
            if count > MAX_ERROR_LINES:
                no_except(lambda: self._proc.stderr.close())
                self._end_message = self._end_message \
                    or "too much error output"
                return

    def _output_thread_runner(self):
        """Body of the thread that reads the process's standard output.

        Text up to the last '>' on a line (and blanks after it) is
        removed.  Lines then starting with '*' are queued for the
        controller thread as long as no end message has been recorded.
        A line starting with '===' begins a dump that is copied to the
        output destination up to and including the next such line.  At
        end of output, the output destination is closed and EOS is
        queued.  The stream is closed early after MAX_OUTPUT_LINES lines
        or if the message queue fills up."""
        count = 0
        dumping = False
        for line in self._proc.stdout:
            count += 1
            if count > MAX_OUTPUT_LINES:
                no_except(lambda: self._proc.stdout.close())
                self._end_message = self._end_message \
                    or "too much output"
                return
            line = line.decode(encoding='ascii', errors='ignore')
            self._log(line, ">")
            line = re.sub(r'^.*> *', '', line)
            if dumping:
                self._output_dest.write(line)
                if re.match(r'===', line):
                    dumping = False
                continue
            if re.match(r'\s*\*', line) and self._end_message is None:
                if not self._enqueue(self._proc_msg_queue, line):
                    no_except(lambda: self._proc.stdout.close())
                    return
            elif re.match(r'===', line):
                self._output_dest.write(line)
                dumping = True
        self._output_dest.close()
        self._enqueue(self._proc_msg_queue, EOS)

    def _our_move(self, win_allowed=True):
        """Wait for the next move or win message from this program, relay
        it to the opponent, if any, and return it with runs of blanks
        collapsed.  Raises Terminate at end of output, on a timeout, on
        a malformed message, or on a win message when WIN_ALLOWED is
        false."""
        msg = self._timed_get(self._proc_msg_queue, 'waiting for my move')
        if msg is EOS:
            raise Terminate
        msg = re.sub(' +', ' ', msg)
        mat = MSG_PATN.match(msg)
        if not mat:
            self._error_exit("malformed move or win message: {}".format(msg))
        if self._other_prog:
            self._other_prog.receive_move(msg)
        if mat.group('win') and not win_allowed:
            self._error_exit("unexpected win message")
        return msg

    def _remote_move(self):
        """Wait for the next move or win message relayed from the
        opponent and return it with runs of blanks collapsed.  Raises
        Terminate on EOS, on a timeout, or on a malformed message."""
        msg = self._timed_get(self._move_queue, 'waiting for opponent')
        if msg is EOS:
            raise Terminate
        msg = re.sub(' +', ' ', msg)
        mat = MSG_PATN.match(msg)
        if not mat:
            self._error_exit("malformed move or win message: {}".format(msg))
        return msg

    def _local_game(self, print_win):
        """Consume this program's messages up to and including a win
        message.  If PRINT_WIN is true (non-empty), print the win message
        on the output destination."""
        self._time_remaining = self._game_time_limit
        while True:
            # _our_move never returns EOS; it raises Terminate instead.
            msg = self._our_move()
            mat = MSG_PATN.match(msg)
            if mat.group('win'):
                if print_win:
                    print(msg.rstrip(), file=self._output_dest)
                self._time_remaining = self._move_time_limit
                return

    def _remote_game(self, print_win):
        """Play a game against the opponent: alternately take a message
        from the opponent (sending its move, if it is one, to this
        program) and a message from this program, until both sides have
        reported the same win.  Raises Terminate (via _error_exit) if
        there is no opponent, if the two wins differ, or if a move
        follows a win.  PRINT_WIN is not used."""
        if self._other_prog is None:
            self._error_exit("invalid remote command in testing file"
                             " (no opponent)")
        self._time_remaining = self._game_time_limit
        prev_win = None
        while True:
            # _remote_move never returns EOS; it raises Terminate instead.
            rmsg = self._remote_move()
            rmat = MSG_PATN.match(rmsg)
            curr_win = rmat.group('win')
            if curr_win:
                if prev_win:
                    if prev_win != curr_win:
                        self._error_exit("received conflicting win")
                    return
                else:
                    prev_win = curr_win
            elif prev_win:
                self._error_exit("received move; expected win")
            else:
                self._send_command(rmat.group('move'))
            msg = self._our_move()
            mat = MSG_PATN.match(msg)
            curr_win = mat.group('win')
            if curr_win:
                if prev_win:
                    if prev_win != curr_win:
                        self._error_exit("received conflicting win")
                    return
                else:
                    prev_win = curr_win
            elif prev_win:
                self._error_exit("made local move; expected win")

    def _win(self):
        """Expect a win message from this program within SHORT_WAIT
        seconds and print it on the output destination."""
        self._time_remaining = SHORT_WAIT
        msg = self._our_move(True)
        if not WIN_PATN.match(msg):
            self._error_exit("expected win; received move")
        print(msg.rstrip(), file=self._output_dest)

    def _set_times(self, move_limit, game_limit):
        """Set the per-move and per-game time limits, in seconds."""
        self._move_time_limit = move_limit
        self._game_time_limit = game_limit

    def _error_exit(self, reason):
        """Terminate this program and its opponent, if any, giving REASON
        as the end_message if there is none already.  Should only be called
        from the controller thread."""
        self._end_message = self._end_message or reason
        if self._other_prog:
            self._other_prog.receive_move(EOS)
            self._other_prog.stop()
        raise Terminate

    def receive_move(self, msg):
        """Queue MSG, a message from the opponent (or EOS), for this
        program's controller thread.  Ignored if the process failed to
        start."""
        if self._proc:
            self._enqueue(self._move_queue, msg)

    def stop(self, reason=None):
        """Terminate this program.  If there is no reason recorded for
        ending yet, use REASON.  The process's input is closed and it is
        given SHORT_WAIT seconds to exit before being killed; any exit
        status other than 0 is recorded as an abnormal exit.  May be
        called more than once, and on a program that failed to start."""
        self._end_message = self._end_message or reason
        if self._proc is None:
            # Never started: there is nothing to shut down.
            return
        no_except(lambda: self._proc.stdin.close())
        self._enqueue(self._proc_msg_queue, EOS)
        rc = None
        try:
            rc = self._proc.wait(timeout=SHORT_WAIT)
        except TimeoutExpired:
            no_except(lambda: self._proc.kill())
            try:
                rc = self._proc.wait(timeout=SHORT_WAIT)
            except TimeoutExpired:
                pass
        if rc != 0:
            self._end_message = \
                self._end_message or "process did not exit normallly."

    def _enqueue(self, queue, msg):
        """Put MSG on QUEUE without blocking.  Returns True on success.
        If QUEUE is full, records that the program seems unresponsive
        and returns False."""
        try:
            queue.put(msg, block=False)
            return True
        except Full:
            self._end_message = \
                self._end_message \
                or "program {} seems to be unresponsive.".format(self._id)
            return False

    def _timed_get(self, queue, where):
        """Return the next item from QUEUE, polling every QUANTUM seconds
        for at most the smaller of the move time limit and the time
        remaining, and charging the time spent waiting against the time
        remaining.  Raises Terminate if an end message is recorded while
        waiting, or if time runs out, in which case the end message
        mentions WHERE."""
        tries = int(min(self._move_time_limit, self._time_remaining) / QUANTUM)
        for _ in range(tries):
            try:
                return queue.get(timeout=QUANTUM)
            except Empty:
                pass
            self._time_remaining -= QUANTUM
            if self._end_message is not None:
                raise Terminate
        self._end_message = \
            self._end_message or "time limit exceeded "  + where
        raise Terminate

    def _log(self, command, typ):
        """If logging is enabled, put COMMAND (str or bytes) on the
        logging queue, prefixed by this program's id and the direction
        marker TYP."""
        if isinstance(command, bytes):
            command = command.decode(encoding='ascii', errors='ignore')
        if self._logging_queue is not None:
            self._logging_queue.put("{}{} {}".format(self._id, typ, command))
        
# Command line:  test-tablut [ --verbose ] PROG1-SCRIPT.in [ PROG2-SCRIPT.in ]
try:
    opts, args = getopt(sys.argv[1:], '', ['verbose'])
except GetoptError:
    Usage()

# With --verbose, all traffic to and from the programs is echoed on the
# standard error by a background printer thread.
logger = None
for opt, _ in opts:
    if opt == '--verbose':
        logger = Queue(100)
        queue_printer(logger, sys.stderr)

if len(args) == 1:
    in1 = open(args[0])
    in2 = None
elif len(args) == 2:
    in1 = open(args[0])
    in2 = open(args[1])
    base2 = splitext(args[1])[0]
else:
    Usage()

# Output and error files are named after the script, with its extension
# replaced.
base1 = splitext(args[0])[0]
out1 = open(base1 + ".out", "w")
err1 = open(base1 + ".err", "w")

prog1 = Prog(get_run_command(in1, args[0]), "[1]", commands_in=in1,
             output_dest=out1, error_dest=err1, logging_queue=logger)

if in2:
    out2 = open(base2 + ".out", "w")
    err2 = open(base2 + ".err", "w")

    prog2 = Prog(get_run_command(in2, args[1]), "[2]", commands_in=in2,
                 output_dest=out2, error_dest=err2,
                 logging_queue=logger)

    if prog1.failed() or prog2.failed():
        # A two-program test needs both programs.  Shut down whichever
        # one did start; stop() is never applied to a program that
        # failed to start.
        for prog in (prog1, prog2):
            if not prog.failed():
                prog.stop()
    else:
        prog1.set_other(prog2)
        prog2.set_other(prog1)
        prog1.start()
        prog2.start()
else:
    prog2 = None
    prog1.start()

prog1.join()
if prog2:
    prog2.join()

# Exit with status 0 only if no program recorded a problem.  Otherwise,
# report each program's message in its own .err file and the first
# available message on the standard error.
if not prog1.end_message() and (not prog2 or not prog2.end_message()):
    sys.exit(0)
else:
    print(file=sys.stderr)
    if prog1.end_message():
        print("Program 1 ended with:", prog1.end_message(), file=err1)
    if prog2 and prog2.end_message():
        print("Program 2 ended with:", prog2.end_message(), file=err2)
    if prog1.end_message():
        print("Program 1 ended with:", prog1.end_message(), file=sys.stderr)
    else:
        print("Program 2 ended with:", prog2.end_message(), file=sys.stderr)

    sys.exit(1)
